from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import tensorflow as tf
import numpy as np

import sys
import time

import sgan_model as sgan

sys.path.append('../')
import image_utils as iu
from datasets import MNISTDataSet as DataSet


# Filesystem locations used by this script, keyed by purpose:
#   output     - prefix for the generated sample image files
#   checkpoint - checkpoint index path
#   model      - path prefix handed to the saver when writing checkpoints
results = dict(
    output='./gen_img/',
    checkpoint='./model/checkpoint',
    model='./model/SGAN-model.ckpt',
)

# Training schedule:
#   global_step      - number of iterations of the training loop
#   logging_interval - every this many steps, log losses, export samples
#                      and save a checkpoint
train_step = dict(
    global_step=250001,
    logging_interval=2500,
)


def _uniform_noise(rows, dim):
    """Return a float32 array of shape [rows, dim] drawn uniformly from [-1, 1)."""
    return np.random.uniform(-1., 1., [rows, dim]).astype(np.float32)


def main():
    """Train the SGAN model on MNIST, periodically logging and checkpointing.

    Runs ``train_step['global_step']`` iterations. Each iteration runs the
    discriminator train ops (skipped while the discriminator is judged
    "overpowered") followed by the generator train ops. Every
    ``train_step['logging_interval']`` steps it evaluates the losses on a
    fresh batch, writes the merged summary, saves a grid of generated images
    under ``results['output']`` and saves a checkpoint to ``results['model']``.
    Prints the total elapsed time when done.
    """
    start_time = time.time()  # Clocking start

    # MNIST Dataset load
    mnist = DataSet(ds_path="./").data

    # GPU configure: allocate GPU memory on demand instead of all up front
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    # The context manager closes the session on exit, so no explicit
    # s.close() is needed afterwards.
    with tf.Session(config=config) as s:
        # SGAN Model
        model = sgan.SGAN(s)

        # Initializing
        s.run(tf.global_variables_initializer())

        # Fixed labels used for every exported sample grid; the matching
        # test images are not needed.
        _, sample_y = mnist.test.next_batch(model.sample_num)

        d_overpowered = False
        for step in range(train_step['global_step']):
            batch_x, batch_y = mnist.train.next_batch(model.batch_size)
            batch_z_0 = _uniform_noise(model.batch_size, model.z_dim)
            batch_z_1 = _uniform_noise(model.batch_size, model.z_dim)

            # D and G updates of one step share the same inputs.
            feed = {
                model.x: batch_x,
                model.y: batch_y,
                model.z_1: batch_z_1,
                model.z_0: batch_z_0,
            }

            # Update D network (skipped while D is ahead of G)
            if not d_overpowered:
                _, d_0_loss, _, _ = s.run([model.d_0_op, model.d_0_loss,
                                           model.d_1_op, model.d_1_loss],
                                          feed_dict=feed)

            # Update G network
            _, g_0_loss, _, _ = s.run([model.g_0_op, model.g_0_loss,
                                       model.g_1_op, model.g_1_loss],
                                      feed_dict=feed)

            # NOTE(review): when the D update was skipped, d_0_loss here is
            # the value from the last step that did update D (or from the
            # last logging step), not a loss for the current batch.
            d_overpowered = d_0_loss < g_0_loss / 2

            if step % train_step['logging_interval'] == 0:
                # Evaluate losses and summaries on a fresh batch (no train ops)
                batch_x, batch_y = mnist.train.next_batch(model.batch_size)
                batch_z_0 = _uniform_noise(model.batch_size, model.z_dim)
                batch_z_1 = _uniform_noise(model.batch_size, model.z_dim)

                d_0_loss, _, g_0_loss, _, summary = s.run(
                    [model.d_0_loss, model.d_1_loss,
                     model.g_0_loss, model.g_1_loss,
                     model.merged],
                    feed_dict={
                        model.x: batch_x,
                        model.y: batch_y,
                        model.z_1: batch_z_1,
                        model.z_0: batch_z_0,
                    })

                d_overpowered = d_0_loss < g_0_loss / 2

                # Print loss
                print("[+] Step %08d => " % step,
                      " D loss : {:.8f}".format(d_0_loss),
                      " G loss : {:.8f}".format(g_0_loss))

                # Generate sample images from the fixed labels and fresh noise
                # (inference only; no train op is run here)
                sample_z_0 = _uniform_noise(model.sample_num, model.z_dim)
                sample_z_1 = _uniform_noise(model.sample_num, model.z_dim)
                samples = s.run(model.g_0,
                                feed_dict={
                                    model.y: sample_y,
                                    model.z_1: sample_z_1,
                                    model.z_0: sample_z_0,
                                })

                # Infer the leading dimension: the generator was fed
                # sample_num rows, which need not equal batch_size.
                samples = np.reshape(samples, [-1] + list(model.image_shape))

                # Summary saver
                model.writer.add_summary(summary, step)

                # Export image generated by model G
                sample_image_height = model.sample_size
                sample_image_width = model.sample_size
                sample_dir = results['output'] + 'train_{:08d}.png'.format(step)

                # Generated image save
                iu.save_images(samples,
                               size=[sample_image_height, sample_image_width],
                               image_path=sample_dir)

                # Model save
                model.saver.save(s, results['model'], global_step=step)

    end_time = time.time() - start_time  # Clocking end

    # Elapsed time
    print("[+] Elapsed time {:.8f}s".format(end_time))


# Run the training loop only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
